﻿#include "ColliderDefinitions.cginc"
#include "ContactHandling.cginc"
#include "DistanceFunctions.cginc"
#include "Simplex.cginc"
#include "Bounds.cginc"
#include "SolverParameters.cginc"
#include "Optimization.cginc"

#pragma kernel GenerateContacts

// Per-particle state. positions, orientations and principalRadii are forwarded
// as-is to Optimize() by the GenerateContacts kernel.
StructuredBuffer<float4> positions;
StructuredBuffer<quaternion> orientations;
StructuredBuffer<float4> principalRadii;
// Declared but not read by the GenerateContacts kernel in this file.
StructuredBuffer<float4> velocities;

// Simplex index data, forwarded to Optimize() together with the start/size
// returned by GetSimplexStartAndSize().
StructuredBuffer<int> simplices;

// Per-collider data, both indexed with the collider index taken from a contact pair.
// transforms[] is composed with worldToSolver[0] to build the box's colliderToSolver transform.
StructuredBuffer<transform> transforms;
StructuredBuffer<shape> shapes;

// Candidate pairs: .x is used as the simplex index, .y as the collider index.
StructuredBuffer<uint2> contactPairs;
// Indexed by shape type (BOX_SHAPE here); gives the offset of the first pair
// of that type inside contactPairs.
StructuredBuffer<int> contactOffsetsPerType;

// Output contacts. The buffer's hidden counter (IncrementCounter) hands out write slots.
RWStructuredBuffer<contact> contacts;
// Shared counters/dispatch arguments. This kernel reads entry [11 + 4*BOX_SHAPE]
// (number of pairs for the box shape type) and raises entries [0] and [3]
// via InterlockedMax after writing a contact.
// NOTE(review): presumably laid out as groups of 4 uints per indirect dispatch - confirm against host code.
RWStructuredBuffer<uint> dispatchBuffer;

// Only element 0 is read by this kernel.
StructuredBuffer<transform> worldToSolver;

// Upper bound on the slot index this kernel is allowed to write in contacts.
uint maxContacts; 


// Generates one contact per (simplex, box collider) candidate pair.
//
// Each thread handles the pair at index id.x within the box section of
// contactPairs, reserves an output slot through the contacts buffer counter,
// runs Optimize() against the box shape and stores the resulting contact.
// After writing, it raises dispatchBuffer[0] and dispatchBuffer[3] so they
// reflect the highest slot written so far.
[numthreads(128, 1, 1)]
void GenerateContacts (uint3 id : SV_DispatchThreadID) 
{
    const uint pairIndex = id.x;

    // Entry #11 in the dispatch buffer is the amount of pairs for the first
    // shape type; entries for subsequent shape types follow every 4 slots.
    const uint pairCount = dispatchBuffer[11 + 4*BOX_SHAPE];
    if (pairIndex >= pairCount)
        return;

    // Reserve an output slot. The counter is bumped even when the buffer is
    // already full; in that case nothing is written.
    const uint slot = contacts.IncrementCounter();
    if (slot >= maxContacts)
        return;

    // Fetch the candidate pair assigned to this thread.
    int firstPair = contactOffsetsPerType[BOX_SHAPE];
    uint2 pair = contactPairs[firstPair + pairIndex];
    int simplexIndex = pair.x;
    int colliderIndex = pair.y;

    // Box shape with its transform composed with the world-to-solver transform.
    Box box;
    box.colliderToSolver = worldToSolver[0].Multiply(transforms[colliderIndex]);
    box.s = shapes[colliderIndex];

    // Simplex range and the initial barycentric coordinates handed to Optimize().
    int simplexSize;
    int simplexStart = GetSimplexStartAndSize(simplexIndex, simplexSize);

    float4 simplexBary = BarycenterForSimplexOfSize(simplexSize);
    float4 simplexPoint;

    SurfacePoint surface = Optimize(box, positions, orientations, principalRadii,
                                    simplices, simplexStart, simplexSize, simplexBary, simplexPoint,
                                    surfaceCollisionIterations, surfaceCollisionTolerance);

    // Assemble the contact; every field not set here stays zero.
    contact result = (contact)0;
    result.bodyA = simplexIndex;
    result.bodyB = colliderIndex;
    result.pointA = simplexBary;
    result.pointB = surface.pos;
    result.normal = surface.normal * box.s.isInverted();

    contacts[slot] = result;

    // Publish the new totals: [0] receives a thread group count (128 threads
    // per group) and [3] the number of contacts written so far.
    // Note: written / 128 + 1 is one group more than strictly required when
    // written is an exact multiple of 128.
    const uint written = slot + 1;
    InterlockedMax(dispatchBuffer[0], written / 128 + 1);
    InterlockedMax(dispatchBuffer[3], written);
}